
import sys
import json
import pyiqa
import torch

# === load logistic parameters ===
import os
# Location of the JSON file holding the fitted logistic coefficients.
# logistic() reads MODEL_DATA_PARAMS[model_name]['beta'] from the parsed result.
MODEL_DATA_PATH = '/root/IQA/IQA-Agent/iqa_models_results/model_fitting_result.json'

# Loaded once at import time; importing this module fails if the file is
# missing or is not valid JSON. The encoding is pinned to UTF-8 so parsing does
# not depend on the process locale.
with open(MODEL_DATA_PATH, encoding='utf-8') as f:
    MODEL_DATA_PARAMS = json.load(f)

def logistic(model_name, X):
    """Map raw metric scores through the fitted five-parameter logistic curve.

    Evaluates ``b1 * (0.5 - 1 / (1 + exp(b2 * (X - b3)))) + b4 * X + b5``,
    where ``b1..b5`` are the ``'beta'`` values stored for ``model_name`` in
    ``MODEL_DATA_PARAMS``.

    Args:
        model_name: Key into ``MODEL_DATA_PARAMS``.
        X: Tensor of raw scores. The coefficients are materialised with the
            same dtype and on the same device as ``X``.

    Returns:
        Tensor of fitted scores, computed element-wise from ``X``.

    Raises:
        KeyError: If ``model_name`` or its ``'beta'`` entry is missing.
        ValueError: If ``beta`` does not unpack into exactly five values
            (``beta`` is assumed to be a flat list of five numbers).
    """
    coeffs = torch.tensor(
        MODEL_DATA_PARAMS[model_name]['beta'],
        dtype=X.dtype,
        device=X.device,
    )
    scale, slope, midpoint, linear_gain, offset = coeffs

    # Sigmoid-shaped component, centred so that it is 0 at X == midpoint.
    exponent = slope * (X - midpoint)
    centred_sigmoid = 0.5 - 1.0 / (1 + torch.exp(exponent))

    # Scaled sigmoid plus a linear trend and a constant offset.
    return scale * centred_sigmoid + linear_gain * X + offset

# Lookup table from the externally visible tool name to the metric name that
# is handed to pyiqa.create_metric() and used as the key for the fitted
# logistic parameters.
TOOL_NAME_TO_MODEL_NAME = {
    # NOTE(review): the entries up to VSI_tool look like full-reference
    # metrics (they need a reference image) — confirm against pyiqa's model list.
    "TopIQ_FR_tool": "topiq_fr",
    "AHIQ_tool": "ahiq",
    "FSIM_tool": "fsim",
    "LPIPS_tool": "lpips",
    "DISTS_tool": "dists",
    "WaDIQaM_FR_tool": "wadiqam_fr",
    "PieAPP_tool": "pieapp",
    "MS_SSIM_tool": "ms_ssim",
    "GMSD_tool": "gmsd",
    "SSIM_tool": "ssim",
    "CKDN_tool": "ckdn",
    "VIF_tool": "vif",
    "PSNR_tool": "psnr",
    "VSI_tool": "vsi",

    # NOTE(review): the remaining entries look like no-reference metrics
    # (single image input) — confirm against pyiqa's model list.
    "QAlign_tool": "qalign",
    "CLIPIQA_tool": "clipiqa+_rn50_512",
    "UNIQUE_tool": "unique",
    "HyperIQA_tool": "hyperiqa",
    "TReS_tool": "tres",
    "WaDIQaM_NR_tool": "wadiqam_nr",
    "ARNIQA_tool": "arniqa",
    "NIQE_tool": "niqe",
    "NIMA_tool": "nima",
    "BRISQUE_tool": "brisque",
    "MANIQA_tool": "maniqa",
    "LIQE_mix_tool": "liqe_mix",
}

def run_tool(tool_name, args, device='cuda'):
    """Score image(s) with the pyiqa metric behind ``tool_name``.

    The raw metric output is passed through ``logistic`` so the returned value
    is the fitted score, not the metric's native score.

    Args:
        tool_name: A key of ``TOOL_NAME_TO_MODEL_NAME``.
        args: Mapping describing the inputs. Either both ``"reference_image"``
            and ``"distorted_image"`` are present (the metric is called as
            ``metric(distorted, reference)``), or ``"image"`` is present (the
            metric is called as ``metric(image)``). The values are passed to
            the metric unchanged.
        device: Device handed to ``pyiqa.create_metric``. Defaults to
            ``'cuda'``.

    Returns:
        The fitted score as a plain Python number (via ``Tensor.item()``, so
        the metric is assumed to yield a single-element tensor).

    Raises:
        ValueError: If ``tool_name`` is not a known tool, or if ``args``
            contains neither supported combination of keys.
    """
    model_name = TOOL_NAME_TO_MODEL_NAME.get(tool_name)
    if model_name is None:
        # Fail with a clear message instead of passing None on to pyiqa.
        raise ValueError(f"Unknown tool name: {tool_name!r}")

    # Validate the request before building the metric: creating a metric can
    # be expensive, so a malformed request should be rejected first.
    if "reference_image" in args and "distorted_image" in args:
        metric_inputs = (args["distorted_image"], args["reference_image"])
    elif "image" in args:
        metric_inputs = (args["image"],)
    else:
        raise ValueError("Unsupported input format.")

    metric = pyiqa.create_metric(model_name, device=device)
    score = metric(*metric_inputs)

    # Apply logistic mapping
    fitted_score = logistic(model_name, score)
    return fitted_score.item()

if __name__ == "__main__":
    # CLI contract: argv[1] is the tool name, argv[2] is a JSON-encoded
    # argument object. Exactly one JSON object is printed to stdout: either
    # {"score": ...} on success or {"error": ...} on failure.
    if len(sys.argv) < 3:
        # Report a usage problem in the same JSON shape as every other error
        # rather than dying with an IndexError traceback.
        print(json.dumps({"error": "Usage: <script> <tool_name> <args_json>"}))
        sys.exit(1)

    tool_name = sys.argv[1]
    args_json = sys.argv[2]
    try:
        # Parsing happens inside the try so malformed JSON is also reported
        # as a JSON error object.
        args = json.loads(args_json)
        result = run_tool(tool_name, args)
        print(json.dumps({"score": result}))
    except Exception as e:
        # Top-level boundary: any failure is reported to the caller as JSON.
        print(json.dumps({"error": str(e)}))
